from asyncio.unix_events import BaseChildWatcher
from genericpath import isfile
from random import uniform
from tabnanny import verbose
# from types import NotImplementedType
import numpy as np

from utils import lr_search_max_loss_diff_pt, svm_search_max_loss_diff_pt, sigmoid, compute_loss, get_grads,\
    margin_model_search_max_loss_pt_contin, get_grads, train_model, test_model

import matplotlib.pyplot as plt
import time
from scipy.optimize import fmin_ncg
import scipy.sparse as sparse

import os
import math
import pickle 

from data_utils import get_data_params
import kkt_attack 

############################################################################################ 
### The attacks below are built on top of the attacks proposed in the original repositories of
### Koh et al. (2022) and Suya et al. (2021).
############################################################################################ 

def kkt_attacks(args,x_lims,total_epsilon, clean_data, tar_model, epsilon_increment=0.005, verbose=False):
    """Run the KKT attack over a grid of positive/negative poison budgets.

    Every (epsilon_pos, epsilon_neg) split of ``total_epsilon`` is handed to
    the KKT solver in ``kkt_attack``; a victim model is retrained on the
    resulting poisoned set and evaluated with ``test_model``. The split that
    yields the highest clean test error is kept as the "best" attack.

    Args:
        args: experiment configuration. This function reads ``model_type``,
            ``dataset``, ``percentile``, ``use_slab``, ``use_sphere`` and
            ``weight_decay`` and forwards ``args`` to ``train_model``,
            ``test_model`` and ``kkt_attack.kkt_for_lr``.
        x_lims: feature constraint, used unchanged for both classes.
        total_epsilon: total poisoning budget that is split between the two
            classes (presumably a fraction of the clean training set size --
            confirm against ``kkt_attack``).
        clean_data: tuple ``(X_train, Y_train, X_test, Y_test)``; ``X_train``
            may be dense or a scipy sparse matrix.
        tar_model: target model exposing ``coef_`` and ``intercept_``.
        epsilon_increment: grid step for ``epsilon_pos`` in
            ``[0, total_epsilon]``.
        verbose: if true, print each epsilon pair before it is tried.

    Returns:
        Two lists:
        ``[best_x, best_y, all_thetas, train_victim_losses, train_01_losses,
        test_victim_losses, test_01_losses]`` and
        ``[best_train_loss, best_train_error, best_test_loss,
        best_test_error]``. ``all_thetas`` has one row per epsilon pair
        (weights followed by the bias in the last column); the loss arrays
        have one entry per epsilon pair.

    Raises:
        ValueError: if ``args.model_type`` is neither ``'svm'`` nor ``'lr'``.
    """
    # Fail fast: previously an unknown model type only surfaced as an
    # unbound-variable error after the expensive KKT setup had already run.
    if args.model_type not in ('svm', 'lr'):
        raise ValueError(
            "Unsupported model_type for KKT attack: %r (expected 'svm' or 'lr')" % (args.model_type,))

    target_theta,target_bias = tar_model.coef_.reshape(-1), tar_model.intercept_[0]

    # use the feature constraints consistently
    x_pos_tuple = x_lims
    x_neg_tuple = x_lims

    X_train,Y_train,X_test,Y_test = clean_data

    # work on copies so the caller's clean data is never handed to the solver
    if sparse.issparse(X_train):
        X_train_cp = sparse.csr_matrix.copy(X_train)
    else:
        X_train_cp = np.copy(X_train)
    Y_train_cp = np.copy(Y_train)

    dataset_name = args.dataset
    # some params not needed for this experiments
    percentile = loss_percentile = int(np.round(float(args.percentile))) # 0.1
    class_map, centroids, centroid_vec, sphere_radii, slab_radii = get_data_params(X_train, Y_train, percentile)
    # setup the kkt attack solver
    model_tmp = train_model(X_train,Y_train,args)
    two_class_kkt, clean_grad_at_target_theta, target_bias_grad, max_losses = kkt_attack.kkt_setup(
        target_theta,
        target_bias,
        X_train_cp, Y_train_cp,
        X_test, Y_test,
        dataset_name,
        percentile,
        loss_percentile,
        model_tmp,
        model_grad=get_grads,
        class_map= class_map,
        use_slab=args.use_slab,
        use_loss=False,
        use_l2=args.use_sphere,
        x_pos_tuple=x_pos_tuple,
        x_neg_tuple=x_neg_tuple,
        model_type=args.model_type)
    ## enumerate all the possible pos and neg epsilons
    epsilon_pairs = []
    # first candidate: the split derived from the bias gradient at the target
    # model, kept only when it is a valid split of the total budget
    epsilon_neg = (total_epsilon - target_bias_grad) / 2
    epsilon_pos = total_epsilon - epsilon_neg
    if (epsilon_neg >= 0) and (epsilon_neg <= total_epsilon):
        epsilon_pairs.append((epsilon_pos, epsilon_neg))
    # remaining candidates: a uniform grid; the 1e-6 slack keeps the
    # end point total_epsilon inside np.arange's half-open range
    for epsilon_pos in np.arange(0, total_epsilon + 1e-6, epsilon_increment):
        epsilon_neg = total_epsilon - epsilon_pos
        epsilon_pairs.append((epsilon_pos, epsilon_neg))
    ## run the kkt attack for all generated epsilon pairs
    target_grad = clean_grad_at_target_theta + ((1 + total_epsilon) * args.weight_decay * target_theta)
    best_test_error = -1e10
    # store the intermediate models to calculate the upper bound of interest
    all_thetas = np.zeros((len(epsilon_pairs), X_train.shape[1]+1))
    train_01_losses = []
    test_01_losses = []
    train_victim_losses = []
    test_victim_losses = []

    for i, (epsilon_pos, epsilon_neg) in enumerate(epsilon_pairs):
        if verbose:
            print('\n## Trying epsilon_pos %s, epsilon_neg %s' % (epsilon_pos, epsilon_neg))
        if args.model_type == 'svm':
            X_modified, Y_modified, obj, x_pos, x_neg, num_pos, num_neg = kkt_attack.kkt_attack(
                two_class_kkt,
                target_grad, target_theta,
                total_epsilon, [epsilon_pos], [epsilon_neg],
                X_train_cp, Y_train_cp,
                class_map, centroids, centroid_vec, sphere_radii, slab_radii,
                [target_bias], target_bias_grad, max_losses)
        else:  # 'lr' (model_type was validated on entry)
            data_dim = X_train.shape[1]
            X_modified, Y_modified, obj, x_pos, x_neg, num_pos, num_neg = kkt_attack.kkt_for_lr(
                data_dim,args,target_grad,target_theta,target_bias,
                total_epsilon, epsilon_pos, epsilon_neg, X_train, Y_train, x_pos_tuple,x_neg_tuple,
                lr=1e-1,num_steps=300,trials=2,optimizer='adam',verbose=False)
        # unique x and y for kkt attack: the poison points are the rows that
        # follow the clean training rows in the modified set
        idx_poison = slice(X_train.shape[0], X_modified.shape[0])
        X_poison = X_modified[idx_poison,:]
        Y_poison = Y_modified[idx_poison]
        # train and test induced model performance, and pick best one
        model_p = train_model(X_modified,Y_modified,args)
        total_train_loss, poison_train_loss, clean_train_loss, clean_test_loss, total_train_acc, \
        poison_train_acc, clean_train_acc, clean_test_acc = test_model(X_modified,\
                 Y_modified,X_train,Y_train,X_test,Y_test,model_p,args,verbose=False)
        if best_test_error < 1-clean_test_acc:
            best_test_error = 1-clean_test_acc
            best_train_error = 1-clean_train_acc
            best_train_loss = clean_train_loss
            best_test_loss = clean_test_loss
            best_x = X_poison
            best_y = Y_poison

        all_thetas[i,:-1] = model_p.coef_.reshape(-1)
        all_thetas[i,-1] = model_p.intercept_[0]
        train_01_losses.append(1-clean_train_acc)
        test_01_losses.append(1-clean_test_acc)
        train_victim_losses.append(clean_train_loss)
        test_victim_losses.append(clean_test_loss)

    train_01_losses = np.array(train_01_losses)
    train_victim_losses = np.array(train_victim_losses)
    test_01_losses = np.array(test_01_losses)
    test_victim_losses = np.array(test_victim_losses)
    return [best_x, best_y, all_thetas,train_victim_losses,train_01_losses,test_victim_losses,test_01_losses],\
    [best_train_loss,best_train_error,best_test_loss,best_test_error]

def mta_attack(args,curr_model,tar_model,x_lims,num_iter,clean_data,use_slab=False,use_sphere=False,defense_pars=None):
    """Iteratively build a poison set, adding one point per iteration.

    In each iteration, for both labels in ``[-1, 1]``, the
    ``*_search_max_loss_diff_pt`` helper matching ``args.model_type`` is
    queried with the current and the target model. The candidate with the
    largest returned loss difference is appended to the training set, the
    model is retrained on the enlarged set, and its performance is recorded.

    Args:
        args: experiment configuration. This function reads ``model_type``,
            ``dataset`` and ``print_every`` and forwards ``args`` to the
            helpers.
        curr_model: starting model exposing ``coef_`` and ``intercept_``.
        tar_model: target model exposing ``coef_`` and ``intercept_``.
        x_lims: feature constraint forwarded to the search helpers.
        num_iter: number of poison points to generate.
        clean_data: tuple ``(X_train, Y_train, X_test, Y_test)``; ``X_train``
            may be dense or a scipy sparse matrix.
        use_slab, use_sphere, defense_pars: forwarded to the SVM search
            helper only.

    Returns:
        Two lists:
        ``[X_poison, Y_poison, all_thetas, train_victim_losses,
        train_01_losses, test_victim_losses, test_01_losses]`` and
        ``[loss_on_curr_model, loss_on_tar_model]``. Every array has one
        entry (or row) per iteration; ``all_thetas`` rows hold the retrained
        weights followed by the bias. The second list holds the loss of each
        selected poison point under the model *before* retraining and under
        the target model.

    Raises:
        ValueError: if ``args.model_type`` is neither ``'lr'`` nor ``'svm'``.
        RuntimeError: if no candidate point could be selected in an
            iteration (e.g. every returned loss difference is NaN).
    """
    # Fail fast: an unknown model type used to surface only as an
    # unbound-variable error inside the attack loop.
    if args.model_type not in ('lr', 'svm'):
        raise ValueError(
            "Unsupported model_type for MTA attack: %r (expected 'lr' or 'svm')" % (args.model_type,))

    X_train,Y_train,X_test,Y_test = clean_data
    total_train_loss, poison_train_loss, clean_train_loss, clean_test_loss, total_train_acc, poison_train_acc, clean_train_acc, clean_test_acc = test_model(X_train,\
        Y_train,X_train,Y_train,X_test,Y_test,curr_model,args,verbose=False)
    print("Clean Model Info: Train Error {}, Train Loss {}, Test Error {}, Test Loss {}".format(1-clean_train_acc,clean_train_loss,\
        1-clean_test_acc,clean_test_loss))

    theta_curr = curr_model.coef_.reshape(-1)
    bias_curr = curr_model.intercept_[0]
    theta_tar = tar_model.coef_.reshape(-1)
    bias_tar = tar_model.intercept_[0]

    classes = [-1,1]
    train_01_losses = np.zeros(num_iter)
    train_victim_losses = np.zeros(num_iter)
    test_01_losses = np.zeros(num_iter)
    test_victim_losses = np.zeros(num_iter)
    loss_on_curr_model = np.zeros(num_iter)
    loss_on_tar_model = np.zeros(num_iter)

    all_thetas = np.zeros((num_iter,X_train.shape[1]+1))
    # poisoned training set starts as a copy of the clean one
    if sparse.issparse(X_train):
        X_train_p = sparse.csr_matrix.copy(X_train)
    else:
        X_train_p = np.copy(X_train)
    Y_train_p = np.copy(Y_train)

    # dataset-specific optimizer settings for the LR point search
    if args.model_type == 'lr':
        if args.dataset == '2d_toy':
            lr = 0.01
            num_steps = 300
            trials = 2
        elif args.dataset in ['enron','imdb']:
            lr = 0.5
            num_steps = 2000
            trials = 10
        else:
            lr = 0.1
            num_steps = 20000
            trials = 10
    for j in range(num_iter):
        # -inf (not -1) so that a candidate is always selectable, and reset the
        # selection so a stale point from the previous iteration is never reused
        best_max_loss_diff = float('-inf')
        best_x = best_y = None
        for cls in classes:
            if args.model_type == 'lr':
                max_loss_diff, max_loss_diff_real, max_x = lr_search_max_loss_diff_pt(X_train.shape[1],curr_model,tar_model,cls,x_lims,\
                args,lr=lr,num_steps=num_steps,trials=trials,optimizer = 'adam',verbose=False,candidate_set=X_train_p)
            else:  # 'svm' (model_type was validated on entry)
                max_loss_diff, max_x = svm_search_max_loss_diff_pt(curr_model,tar_model,cls,x_lims,args,verbose=False,\
                                                                   use_slab=use_slab,use_sphere=use_sphere,defense_pars=defense_pars)
            if best_max_loss_diff < max_loss_diff:
                best_max_loss_diff = max_loss_diff
                best_y = cls
                best_x = max_x

        if best_y is None:
            raise RuntimeError(
                "No poison candidate could be selected at iteration %d" % j)

        # wrap the selected point as a single-row batch
        if sparse.issparse(X_train):
            best_x = sparse.csr_matrix([best_x])
        else:
            best_x = np.array([best_x])
        best_y = np.array([best_y])

        # losses of the new point under the pre-retraining and target models
        loss_on_curr = np.mean(compute_loss(args.model_type,best_x,best_y,theta_curr,bias_curr))
        loss_on_tar = np.mean(compute_loss(args.model_type,best_x,best_y,theta_tar,bias_tar))
        loss_on_curr_model[j] = loss_on_curr
        loss_on_tar_model[j] = loss_on_tar

        if sparse.issparse(X_train):
            X_train_p = sparse.vstack((X_train_p, best_x), format='csr')
        else:
            X_train_p = np.concatenate((X_train_p,best_x),axis=0)
        Y_train_p = np.concatenate((Y_train_p,best_y),axis=0)

        curr_model = train_model(X_train_p,Y_train_p,args)
        # test the updated model
        total_train_loss, poison_train_loss, clean_train_loss, clean_test_loss, total_train_acc, poison_train_acc, clean_train_acc, clean_test_acc = test_model(X_train_p,\
            Y_train_p,X_train,Y_train,X_test,Y_test,curr_model,args,verbose=False)
        train_01_losses[j] = 1-clean_train_acc
        train_victim_losses[j] = clean_train_loss
        test_01_losses[j] = 1-clean_test_acc
        test_victim_losses[j] = clean_test_loss

        if j % args.print_every == 0:
            print("Iter {} Attack Info: Train Error {:.6f}, Train Loss {:.6f}, Test Error {:.6f}, Test Loss {:.6f}, Max Loss Diff: {:.6f}".format(j,\
                1-clean_train_acc,clean_train_loss,1-clean_test_acc,clean_test_loss,best_max_loss_diff))

        # record the intermediate model weights
        theta_curr = curr_model.coef_.reshape(-1)
        bias_curr = curr_model.intercept_[0]
        all_thetas[j,:-1] = theta_curr
        all_thetas[j,-1] = bias_curr

    # record the generated poison points (rows appended after the clean data)
    X_poison = X_train_p[X_train.shape[0]:,:]
    Y_poison = Y_train_p[X_train.shape[0]:]

    return [X_poison,Y_poison,all_thetas,train_victim_losses,train_01_losses,test_victim_losses,test_01_losses], [loss_on_curr_model,loss_on_tar_model]

def min_max_attack(epsilon,args,x_lims,num_iter,clean_data,num_sgd_steps=1,min_max_lr_matlab=0.03,burn_frac = 0.0,tar_model=None,C=None,\
                        use_slab=False,use_sphere=False,defense_pars=None):
    """Run the min-max attack: alternate a max-loss point search and a model update.

    Starting from a model trained on the clean data, every iteration asks
    ``margin_model_search_max_loss_pt_contin`` for a point ``(best_x, best_y)``
    under the current weights and then takes ``num_sgd_steps`` AdaGrad-style
    steps on the clean gradient plus ``epsilon`` times the gradient at that
    point (plus the weight-decay term on the weights). After the burn-in
    iterations, each found point is also appended to the training set, a
    victim model is retrained on it, and its losses are recorded.

    Args:
        epsilon: weight of the poison-point gradient in the model update.
        args: experiment configuration. This function reads ``model_type``,
            ``dataset``, ``weight_decay`` and ``print_every`` and forwards
            ``args`` to the helpers.
        x_lims: feature constraint, used unchanged for both classes.
        num_iter: number of attack iterations (poison points generated).
        clean_data: tuple ``(X_train, Y_train, X_test, Y_test)``; ``X_train``
            may be dense or a scipy sparse matrix.
        num_sgd_steps: number of model-update steps per iteration.
        min_max_lr_matlab: step size of the model update.
        burn_frac: ``round(burn_frac * num_iter)`` extra burn-in iterations
            are run first; they update the weights but add no poison point
            and record no losses.
        tar_model: optional target model (``coef_`` / ``intercept_``) whose
            parameters are forwarded to the point search.
        C, use_slab, use_sphere, defense_pars: forwarded to the point search.

    Returns:
        ``[X_poison, Y_poison, all_thetas, train_victim_losses,
        train_01_losses, test_victim_losses, test_01_losses]``.
        ``all_thetas`` has one row per iteration *including* burn-in and holds
        the attack's internal weights (followed by the bias), not those of
        the retrained victim model; the loss arrays have one entry per
        post-burn-in iteration.
    """
    # num_sgd_steps: # of iterations to update the model
    # min max attack is not impacted by whether the test data is used during attack optimization process
    X_train,Y_train,X_test,Y_test = clean_data

    clean_model = train_model(X_train,Y_train,args)
    total_train_loss, poison_train_loss, clean_train_loss, clean_test_loss, total_train_acc, poison_train_acc, clean_train_acc, clean_test_acc = test_model(X_train,\
        Y_train,X_train,Y_train,X_test,Y_test,clean_model,args,verbose=False)
    curr_theta, curr_bias = clean_model.coef_.reshape(-1), clean_model.intercept_[0]
    print("Clean Model Info: Train Error {}, Train Loss {}, Test Error {}, Test Loss {}".format(1-clean_train_acc,clean_train_loss,\
        1-clean_test_acc,clean_test_loss))

    # poisoned training set starts as a copy of the clean one
    if sparse.issparse(X_train):
        X_train_p = sparse.csr_matrix.copy(X_train)
    else:
        X_train_p = np.copy(X_train)
    Y_train_p = np.copy(Y_train)

    # some attack related params
    classes = [-1,1]
    # squared-gradient accumulators, initialised to 1;
    # these values are based on an assumption that clean model training is solved exactly, instead of greedy opt strategies
    g_pow = 1
    g_bias_pow = 1
    n_burn = round(burn_frac*num_iter)
    # the attack iteration, this attack needs to consider running the attack for additional (burn-in) iterations
    max_iter = n_burn + num_iter

    train_01_losses = []
    train_victim_losses = []
    test_01_losses = []
    test_victim_losses = []
    all_thetas = np.zeros((max_iter,X_train.shape[1]+1))

    x_lim_tuples = (x_lims,x_lims)

    if tar_model is not None:
        theta_tar = tar_model.coef_.reshape(-1)
        bias_tar = tar_model.intercept_[0]
    else:
        theta_tar = bias_tar = None

    for i in range(max_iter):
        # NOTE(review): without a target model bias_tar is forwarded as
        # np.array([None]); presumably the helper ignores it when theta_tar
        # is None -- confirm against utils.
        max_loss, max_loss_real, best_x, best_y = margin_model_search_max_loss_pt_contin(False,args.model_type,curr_theta,np.array([curr_bias]),classes,x_lim_tuples,\
            args,theta_tar=theta_tar,bias_tar=np.array([bias_tar]), C=C,use_slab=use_slab,use_sphere=use_sphere,defense_pars=defense_pars)
        # one-time notice, emitted on the first iteration only
        if i == 0 and (args.model_type == 'svm' or args.dataset != 'adult'):
            print("[Warning!] Recommend running matlab version for SVM or other datasets!")

        # update the train data, only used for recording the results of induced model
        if i > n_burn-1:
            if sparse.issparse(X_train):
                best_x = sparse.csr_matrix(best_x)
                X_train_p = sparse.vstack((X_train_p,best_x), format='csr')
            else:
                best_x = np.array(best_x)
                X_train_p = np.concatenate((X_train_p,best_x),axis=0)
            Y_train_p = np.concatenate((Y_train_p,np.array([best_y])),axis=0)
            poison_model = train_model(X_train_p,Y_train_p,args)
            total_train_loss, poison_train_loss, clean_train_loss, clean_test_loss, total_train_acc, poison_train_acc, clean_train_acc, clean_test_acc = test_model(X_train_p,\
                Y_train_p,X_train,Y_train,X_test,Y_test,poison_model,args,verbose=False)

        # update the model weight using generated poisoning point
        for _ in range(num_sgd_steps):
            # labels are passed to get_grads unchanged
            # (original note: the lr gradient needs labels in {-1,1})
            g_c, g_bias_c = get_grads(args.model_type,curr_theta,curr_bias,X_train,Y_train)
            g_p, g_bias_p = get_grads(args.model_type,curr_theta,curr_bias,best_x, np.array([best_y]))

            g = g_c + epsilon * g_p
            g = g + args.weight_decay * curr_theta # add gradient of regularization term
            g_pow = g_pow + g**2
            g_bias = g_bias_c + epsilon * g_bias_p # no need to regularize bias term
            g_bias_pow = g_bias_pow + g_bias**2
            # update model weights, scaling by the accumulated squared gradients
            curr_theta = curr_theta - min_max_lr_matlab * g / (np.sqrt(g_pow))
            curr_bias = curr_bias - min_max_lr_matlab * g_bias / (np.sqrt(g_bias_pow))

        # record the intermediate model weights
        all_thetas[i,:-1] = curr_theta
        all_thetas[i,-1] = curr_bias

        if i > n_burn-1:
            train_01_losses.append(1-clean_train_acc)
            train_victim_losses.append(clean_train_loss)
            test_01_losses.append(1-clean_test_acc)
            test_victim_losses.append(clean_test_loss)

            if i % args.print_every == 0:
                print("Burn_In: {}, Iter {} Attack Info: Test Error {}, Test Loss {}, Max Loss: {:.5f}".format(n_burn,i-n_burn,1-clean_test_acc,clean_test_loss,\
                max_loss))

    train_01_losses = np.array(train_01_losses)
    train_victim_losses = np.array(train_victim_losses)
    test_01_losses = np.array(test_01_losses)
    test_victim_losses = np.array(test_victim_losses)
    # note that the original min-max attack may run for additional iterations
    return [X_train_p[X_train.shape[0]:,:],Y_train_p[X_train.shape[0]:],all_thetas,\
        train_victim_losses,train_01_losses,test_victim_losses,test_01_losses]
